# -*- coding: utf-8 -*-
"""DQN.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1PZKHtnECy-6vg-zGzEQY-U7ufapqUczO
"""

import numpy as np
import random
from collections import deque
import tensorflow as tf
from tensorflow import keras
import gymnasium as gym

# --- 1. Define the DQN Agent ---
class DQNAgent:
    """Deep Q-Network agent with experience replay and a separate target network.

    The online network (``self.model``) is trained on minibatches sampled from
    a bounded replay memory. Bootstrap targets come from ``self.target_model``,
    which only changes when :meth:`update_target_model` is called.

    Args:
        state_size: Length of the flat observation vector fed to the network.
        action_size: Number of discrete actions (size of the output layer).
        gamma: Discount factor used in the Bellman target.
        epsilon: Initial exploration rate of the epsilon-greedy policy.
        epsilon_decay: Multiplicative decay applied to ``epsilon`` after each
            training step performed by :meth:`replay`.
        epsilon_min: Lower bound for ``epsilon``.
        learning_rate: Adam learning rate used when compiling the networks.
        batch_size: Number of transitions sampled per training step.
        train_start: Minimum number of stored transitions before training starts.
        memory_size: Capacity of the replay memory (oldest entries are evicted).
    """

    def __init__(self, state_size, action_size, *, gamma=0.95, epsilon=1.0,
                 epsilon_decay=0.995, epsilon_min=0.01, learning_rate=0.01,
                 batch_size=64, train_start=1000, memory_size=2000):
        self.state_size = state_size
        self.action_size = action_size
        # Experience replay memory; deque(maxlen=...) drops the oldest entry when full.
        self.memory = deque(maxlen=memory_size)
        self.gamma = gamma                    # Discount factor
        self.epsilon = epsilon                # Exploration-exploitation trade-off
        self.epsilon_decay = epsilon_decay
        self.epsilon_min = epsilon_min
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        self.train_start = train_start        # Start training after this many experiences

        # Build main model and target model; the target starts as an exact copy.
        self.model = self._build_model()
        self.target_model = self._build_model()
        self.update_target_model()

    def _build_model(self):
        """Build and compile a small MLP mapping a state to one Q-value per action."""
        model = keras.Sequential()
        model.add(keras.layers.Input(shape=(self.state_size,)))
        model.add(keras.layers.Dense(24, activation='relu'))
        model.add(keras.layers.Dense(24, activation='relu'))
        # Linear output: Q-values are unbounded regression targets.
        model.add(keras.layers.Dense(self.action_size, activation='linear'))
        model.compile(loss='mse', optimizer=keras.optimizers.Adam(learning_rate=self.learning_rate))
        return model

    def update_target_model(self):
        """Copy the main model's weights into the target model."""
        self.target_model.set_weights(self.model.get_weights())

    def remember(self, state, action, reward, next_state, done):
        """Append one transition to the replay memory.

        ``replay`` reads ``state[0]`` and ``next_state[0]``, so states are
        expected to carry a leading batch axis of size 1. The training loop in
        this file reshapes observations to ``(1, state_size)``.
        """
        self.memory.append((state, action, reward, next_state, done))

    def act(self, state):
        """Choose an action with an epsilon-greedy policy.

        Args:
            state: Batch passed directly to ``model.predict``; only the Q-values
                of its first row are used.

        Returns:
            A uniformly random action with probability ``epsilon``, otherwise
            the index of the largest predicted Q-value.
        """
        if np.random.rand() <= self.epsilon:
            return random.randrange(self.action_size)
        q_values = self.model.predict(state, verbose=0)
        return np.argmax(q_values[0])

    @staticmethod
    def _bellman_targets(current_q, next_q, actions, rewards, dones, gamma):
        """Return the regression targets for one minibatch (pure, no model calls).

        Each row is a copy of ``current_q``. Only the entry of the taken action is
        replaced, by ``r`` for terminal transitions or ``r + gamma * max(next_q)``
        otherwise. Under the MSE loss, non-taken actions therefore contribute no error.
        """
        targets = np.array(current_q, dtype=float)  # copy; never mutate the predictions
        best_next = np.max(np.asarray(next_q), axis=1)
        bootstrap = np.where(np.asarray(dones, dtype=bool), 0.0, gamma * best_next)
        rows = np.arange(targets.shape[0])
        targets[rows, np.asarray(actions)] = np.asarray(rewards, dtype=float) + bootstrap
        return targets

    def replay(self):
        """Run one training step on a random minibatch, then decay epsilon.

        Does nothing until the memory holds at least ``train_start`` transitions
        and at least ``batch_size`` of them. The second condition stops
        ``random.sample`` from raising ``ValueError``.
        """
        if len(self.memory) < max(self.train_start, self.batch_size):
            return

        minibatch = random.sample(self.memory, self.batch_size)

        # Strip the size-1 batch axis each stored state carries.
        states = np.array([t[0][0] for t in minibatch])
        actions = np.array([t[1] for t in minibatch])
        rewards = np.array([t[2] for t in minibatch], dtype=float)
        next_states = np.array([t[3][0] for t in minibatch])
        dones = np.array([t[4] for t in minibatch], dtype=bool)

        # Current estimates from the main model, bootstrap values from the target model.
        current_q_values = self.model.predict(states, verbose=0)
        next_q_values = self.target_model.predict(next_states, verbose=0)

        target_q_values = self._bellman_targets(
            current_q_values, next_q_values, actions, rewards, dones, self.gamma
        )
        self.model.fit(states, target_q_values, epochs=1, verbose=0)

        # Decay exploration, never dropping below the configured floor.
        if self.epsilon > self.epsilon_min:
            self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

# --- 2. Main Training Loop ---
if __name__ == '__main__':
    # Create the CartPole-v1 environment
    env = gym.make('CartPole-v1')
    state_size = env.observation_space.shape[0]
    action_size = env.action_space.n

    # Initialize the DQN agent
    agent = DQNAgent(state_size, action_size)

    # Set up parameters for training
    episodes = 500
    # Copy the main model's weights into the target model every C episodes.
    target_update_frequency = 10
    # Success criterion: mean episode length over the last `solve_window` episodes.
    # NOTE: 195 is the classic CartPole-v0 bar; gymnasium registers CartPole-v1
    # with reward_threshold=475, so this is a looser criterion than the official one.
    solve_score = 195
    solve_window = 100

    scores = deque(maxlen=solve_window)  # episode lengths of the most recent episodes
    average_score = 0.0  # defined before the loop so it is never read unassigned
    solved = False

    try:
        for e in range(episodes):
            state, info = env.reset()
            state = np.reshape(state, [1, state_size])  # add batch axis for the network
            time_step = 0
            done = False

            while not done:
                # env.render() # Uncomment to see the environment visually

                action = agent.act(state)  # Agent chooses an action
                next_state, reward, terminated, truncated, info = env.step(action)
                next_state = np.reshape(next_state, [1, state_size])
                # `terminated`: the environment reached a terminal state (pole fell /
                # cart left the track). `truncated`: the time limit was hit. Either
                # one ends the episode. Only a real termination should stop
                # bootstrapping in replay(), so only `terminated` is stored as done.
                done = terminated or truncated

                # Reward shaping: +1 per surviving step, -10 when the pole falls.
                reward = -10 if terminated else 1

                agent.remember(state, action, reward, next_state, terminated)
                state = next_state
                time_step += 1

            scores.append(time_step)  # episode length is the CartPole score
            average_score = np.mean(scores)
            print(f"Episode: {e+1}/{episodes}, Score: {time_step}, Average Score: {average_score:.2f}, Epsilon: {agent.epsilon:.2f}, Time steps: {time_step}")

            if e % target_update_frequency == 0:
                agent.update_target_model()

            # Train once per episode; replay() is a no-op until enough experience is stored.
            agent.replay()

            # Only declare success once a full window of episodes has been observed.
            if len(scores) == solve_window and average_score >= solve_score:
                solved = True
                break
    finally:
        env.close()

    print("Training finished!")
    if solved:
        print("CartPole-v1 solved!")
    else:
        print("CartPole-v1 not solved, try increasing episodes or tuning hyperparameters.")